# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from graphlearn.python.nn.tf.data.feature_handler import FeatureHandler


class EgoGraph(object):
  """ `EgoGraph` is a basic data structure used to describe a sampled graph.
  It consists of src `Data` and src's neighbors (nodes and edges) `Data`.
  The `EgoGraph` is mainly used to represent subgraphs generated by fixed-size
  neighbor sampling, in which the data can be efficiently organized in dense
  format and the model can be computed using the dense operators.

  Args:
    src: A `Data`/Tensor object used to describe the centric nodes.
    nbr_nodes: A list of `Data`/Tensor instance to describe neighborhood nodes,
      one entry per hop.
    node_schema: A list of tuple to describe the FeatureSpec of src and
      neighbor nodes, or None when no feature transformation is needed (e.g.
      the graph returned by `transform`). Each tuple is formatted with
      (name, spec), in which `name` is node's type, and `spec` is a
      FeatureSpec object. Be sure that `len(node_schema) == len(nbr_nums) + 1`.
    nbr_nums: A list of number of neighbor nodes per hop. Stored as a numpy
      array.
    nbr_edges: A list of `Data`/Tensor instance to describe neighborhood edges.
      Defaults to a fresh empty list (no edge data).
    edge_schema: A list of tuple to describe the `FeatureSpec` of neighbor
      edges. Defaults to a fresh empty list.
    **kwargs: Accepted and ignored.
  """
  def __init__(self,
               src,
               nbr_nodes,
               node_schema,
               nbr_nums,
               nbr_edges=None,
               edge_schema=None,
               **kwargs):
    self._src = src
    self._nbr_nodes = nbr_nodes
    self._node_schema = node_schema
    # `None` defaults instead of `[]`: a list default is created once at
    # function definition time and would be shared by every instance.
    self._nbr_edges = [] if nbr_edges is None else nbr_edges
    self._edge_schema = [] if edge_schema is None else edge_schema
    if self._node_schema is not None:
      # One schema entry for src plus one per hop.
      assert len(self._node_schema) == len(nbr_nums) + 1
    self._nbr_nums = np.array(nbr_nums)

  @property
  def src(self):
    return self._src

  @property
  def node_schema(self):
    return self._node_schema

  @property
  def nbr_nodes(self):
    return self._nbr_nodes

  @property
  def nbr_nums(self):
    return self._nbr_nums

  @property
  def edge_schema(self):
    return self._edge_schema

  @property
  def nbr_edges(self):
    return self._nbr_edges

  def hop_node(self, i):
    """ Get the neighbor nodes of the centric src at zero-based hop index
    `i`, i.e. `nbr_nodes[i]`. Per the original design the value is a tensor
    with shape [batch_size * k_1 * ... * k_{i+1}, dim], where k_j is the
    expand neighbor count at (one-based) hop j and dim is the sum of all
    feature dimensions, which may be different due to kinds of vertex types.
    """
    return self._nbr_nodes[i]

  def hop_edge(self, i):
    """ Get the neighbor edges at zero-based hop index `i`, i.e.
    `nbr_edges[i]`.

    Raises:
      ValueError: if this graph carries no edge data.
    """
    if len(self._nbr_edges) == 0:
      raise ValueError("No edge data.")
    return self._nbr_edges[i]

  def transform(self, transform_func=None):
    """Transforms `EgoGraph`. Default transformation is encoding nodes feature
    to embedding.

    Args:
      transform_func: Optional function that takes in an `EgoGraph` object and
        returns a transformed version. When given, its result is returned
        and the default transformation is skipped.

    Returns:
      The transformed graph. With the default transformation: `self` when
      `node_schema` is None, otherwise a new `EgoGraph` whose src, neighbor
      nodes and neighbor edges have been passed through `FeatureHandler`, and
      whose `node_schema` is None.
    """
    if transform_func is not None:
      return transform_func(self)

    if self.node_schema is None:
      return self

    def transform_feat(feat, schema):
      # schema is a (name, spec) tuple.
      feat_handler = FeatureHandler(schema[0], schema[1])
      return feat_handler.forward(feat)

    assert len(self.node_schema) == (len(self.nbr_nodes) + 1)

    src = transform_feat(self.src, self.node_schema[0])
    # node_schema[0] belongs to src; the remaining entries pair up with hops.
    nbr_nodes = [transform_feat(nbr, schema)
                 for nbr, schema in zip(self.nbr_nodes, self.node_schema[1:])]

    nbr_edges = []
    if self.nbr_edges:
      assert len(self.edge_schema) == len(self.nbr_edges)
      nbr_edges = [transform_feat(edge, schema)
                   for edge, schema in zip(self.nbr_edges, self.edge_schema)]

    return EgoGraph(src, nbr_nodes, None, self.nbr_nums, nbr_edges)

  def __getitem__(self, key):
    """Dict-style attribute access; returns None for unknown keys."""
    return getattr(self, key, None)

  def __setitem__(self, key, value):
    """Dict-style attribute assignment (delegates to `setattr`)."""
    setattr(self, key, value)
